{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "1f2969c8",
   "metadata": {},
   "source": [
    "# P-Chat 🔒💬\n",
    "\n",
    "A privacy-focused bring-your-own-document (BYOD) solution that lets you use LLMs to interact with your documents. Nothing is persisted: uploaded documents, their embeddings, and the conversation exist only in ephemeral memory.\n",
    "\n",
    "## Features\n",
    "- Parent-child chunking: small chunks are matched, and their full parent documents are passed to the LLM as context\n",
    "- Chunk augmentation with parent data (e.g. the top-level Markdown header) for structured documents\n",
    "- Streamed responses for a better user experience\n",
    "- Secure by design: no data is stored permanently\n",
    "- Uses a locally running Ollama instance for complete privacy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df7609cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Runtime dependencies: Gradio for the UI plus the LangChain integrations for Ollama and Chroma.\n",
    "%pip install -qU gradio langchain_ollama langchain_chroma langchain_community"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "144bdf7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import logging\n",
    "import sys\n",
    "import uuid\n",
    "from enum import StrEnum\n",
    "from pathlib import Path\n",
    "\n",
    "import gradio as gr\n",
    "from langchain_core.documents import Document\n",
    "from langchain_text_splitters import RecursiveCharacterTextSplitter, MarkdownHeaderTextSplitter\n",
    "from langchain_ollama import OllamaEmbeddings, ChatOllama\n",
    "from langchain.storage import InMemoryStore\n",
    "from langchain_chroma import Chroma\n",
    "from langchain_community.document_loaders import TextLoader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dfdb143d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shared 'rag' logger; DEBUG level so chunking and retrieval details are visible.\n",
    "logger = logging.getLogger('rag')\n",
    "logger.setLevel(logging.DEBUG)\n",
    "\n",
    "# Attach a stdout handler only once, so re-running this cell does not duplicate log lines.\n",
    "if not logger.handlers:\n",
    "    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')\n",
    "    handler = logging.StreamHandler(sys.stdout)\n",
    "    handler.setFormatter(formatter)\n",
    "    logger.addHandler(handler)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0e2f176b",
   "metadata": {},
   "source": [
    "## RAG Pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78f2a554",
   "metadata": {},
   "outputs": [],
   "source": [
    "def pretty_print(l: list[Document | tuple[Document, float]]):\n",
    "    '''Log each item at DEBUG level, numbered from 1, with its size, metadata and content.\n",
    "\n",
    "    Items may be plain Documents or (Document, score) tuples; for tuples the\n",
    "    score is logged as well.\n",
    "    '''\n",
    "    for idx, entry in enumerate(l, start=1):\n",
    "        logger.debug('-' * 80 + '\\n')\n",
    "\n",
    "        scored = isinstance(entry, tuple)\n",
    "        if scored:\n",
    "            doc, score = entry\n",
    "        else:\n",
    "            doc = entry\n",
    "\n",
    "        logger.debug(f'{idx}. characters: {len(doc.page_content)}\\n')\n",
    "        score_line = f'Score: {score}\\n' if scored else ''\n",
    "        logger.debug(f'{score_line}Metadata: {doc.metadata}\\nContent: {doc.page_content}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "42893f0b",
   "metadata": {},
   "source": [
    "### Indexing\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20ad0e80",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Local Ollama models: a small chat model plus a dedicated embedding model.\n",
    "model_id = 'qwen3:0.6b'\n",
    "embedding_model = 'nomic-embed-text:latest'\n",
    "\n",
    "embeddings = OllamaEmbeddings(model=embedding_model)\n",
    "# Low temperature for answers that stay close to the retrieved context.\n",
    "model = ChatOllama(model=model_id, temperature=0.1)\n",
    "\n",
    "# Child chunks are embedded into Chroma (no persist_directory is configured);\n",
    "# full parent documents are kept in an in-memory key/value store keyed by parent id.\n",
    "vectorstore = Chroma(collection_name='p-chat', embedding_function=embeddings)\n",
    "docstore = InMemoryStore()\n",
    "\n",
    "class Metadata(StrEnum):\n",
    "    '''Metadata keys attached to parent documents and to their child chunks.'''\n",
    "\n",
    "    ID = 'id'  # unique id of a parent document or chunk\n",
    "    PARENT_ID = 'parent_id'  # set on chunks: id of the parent document\n",
    "    SOURCE = 'source'  # file name of the uploaded document\n",
    "    FILE_TYPE = 'file_type'  # file extension without the leading dot\n",
    "\n",
    "\n",
    "# File extension (with leading dot) -> loader class; both supported formats are plain text.\n",
    "LOADER_MAPPING = dict.fromkeys(('.md', '.txt'), TextLoader)\n",
    "\n",
    "def load_documents(file_path: Path) -> list[Document]:\n",
    "    '''Load a file with the loader registered for its extension in LOADER_MAPPING.\n",
    "\n",
    "    The extension lookup is case-insensitive, so NOTES.MD is handled like notes.md.\n",
    "\n",
    "    Args:\n",
    "        file_path: Path (or string path) of the uploaded file.\n",
    "\n",
    "    Returns:\n",
    "        The documents produced by the loader, or an empty list when no loader\n",
    "        is configured for the extension.\n",
    "    '''\n",
    "    file_path = Path(file_path)\n",
    "    # LOADER_MAPPING keys are lower-case, so normalise the suffix before looking it up.\n",
    "    extension = file_path.suffix.lower()\n",
    "    logger.info(f'Loading loader for {extension}')\n",
    "    loader_cls = LOADER_MAPPING.get(extension)\n",
    "\n",
    "    if loader_cls is None:\n",
    "        logger.warning(f'No loader configured for {extension}')\n",
    "        return []\n",
    "\n",
    "    loader = loader_cls(file_path)\n",
    "    documents = loader.load()\n",
    "    logger.info(f'{len(documents)} loaded for {file_path.name}')\n",
    "\n",
    "    return documents\n",
    "\n",
    "\n",
    "def preprocess(documents: list[Document]) -> list[Document]:\n",
    "    '''Give each loaded document a fresh id and normalised source/file-type metadata.\n",
    "\n",
    "    Mutates each document's metadata in place and returns the same list:\n",
    "      - id: a random UUID4 string, later used as the parent id of its chunks.\n",
    "      - source: the bare file name, with any directory part stripped.\n",
    "      - file_type: the lower-cased extension without the leading dot (e.g. md for\n",
    "        notes.md), or an empty string when the file has no extension.\n",
    "    '''\n",
    "    for doc in documents:\n",
    "        metadata = doc.metadata\n",
    "        # Path(...).name copes with OS-native separators, unlike a hard-coded split('/');\n",
    "        # a missing source yields an empty name instead of crashing.\n",
    "        source_name = Path(metadata.get(Metadata.SOURCE) or '').name\n",
    "\n",
    "        metadata[Metadata.ID] = str(uuid.uuid4())\n",
    "        metadata[Metadata.SOURCE] = source_name\n",
    "        metadata[Metadata.FILE_TYPE] = Path(source_name).suffix.lstrip('.').lower()\n",
    "\n",
    "    return documents\n",
    "\n",
    "\n",
    "def index_document(file_path) -> int:\n",
    "    '''Load, preprocess, chunk and index a single uploaded file.\n",
    "\n",
    "    Each parent document is stored in docstore under its id, and its child chunks\n",
    "    are embedded into vectorstore, so retrieval can map a chunk back to its parent.\n",
    "\n",
    "    Args:\n",
    "        file_path: Path (or string path) of the file to index.\n",
    "\n",
    "    Returns:\n",
    "        Number of parent documents indexed (0 when the file type is unsupported).\n",
    "    '''\n",
    "    documents = load_documents(Path(file_path))\n",
    "    preprocessed_docs = preprocess(documents)\n",
    "    logger.debug([doc.metadata for doc in preprocessed_docs])\n",
    "\n",
    "    for doc in preprocessed_docs:\n",
    "        # Store the parent first so no chunk in the vector store points at a missing parent.\n",
    "        docstore.mset([(doc.metadata.get(Metadata.ID), doc)])\n",
    "\n",
    "        chunks = chunk_documents(doc)\n",
    "        if not chunks:\n",
    "            # e.g. an empty file; Chroma rejects an empty batch, so skip the add.\n",
    "            logger.warning(f'No chunks produced for {doc.metadata.get(Metadata.SOURCE)}')\n",
    "            continue\n",
    "        vectorstore.add_documents(chunks)\n",
    "\n",
    "    return len(preprocessed_docs)\n",
    "\n",
    "\n",
    "def chunk_documents(parent: Document, chunk_size: int = 400, chunk_overlap: int = 80) -> list[Document]:\n",
    "    '''Split a parent document into child chunks linked back to it via parent_id.\n",
    "\n",
    "    Markdown files are split on #, ## and ### headers; a chunk under a top-level\n",
    "    header (stored as employee_name) is prefixed with that header, so the chunk\n",
    "    stays attributable once separated from its parent. Other files go through a\n",
    "    recursive character splitter.\n",
    "\n",
    "    Args:\n",
    "        parent: Preprocessed document carrying id and file_type metadata.\n",
    "        chunk_size: Maximum characters per chunk for the non-markdown splitter.\n",
    "        chunk_overlap: Characters shared by consecutive non-markdown chunks.\n",
    "\n",
    "    Returns:\n",
    "        Child documents with ids of the form <parent_id>-<n> (n starting at 1),\n",
    "        each carrying the parent metadata plus parent_id.\n",
    "    '''\n",
    "    # preprocess stores file_type without the leading dot ('md'); accept '.md' too and ignore case.\n",
    "    file_type = (parent.metadata.get(Metadata.FILE_TYPE) or '').lstrip('.').lower()\n",
    "    if file_type == 'md':\n",
    "        headers_to_split_on = [\n",
    "            ('#', 'employee_name'),\n",
    "            ('##', 'section'),\n",
    "            ('###', 'Header 3'),\n",
    "        ]\n",
    "        markdown_splitter = MarkdownHeaderTextSplitter(\n",
    "            headers_to_split_on=headers_to_split_on\n",
    "        )\n",
    "        chunks = markdown_splitter.split_text(parent.page_content)\n",
    "    else:\n",
    "        text_splitter = RecursiveCharacterTextSplitter(\n",
    "            chunk_size=chunk_size,\n",
    "            chunk_overlap=chunk_overlap,\n",
    "            separators=['\\n\\n', '\\n', ' ', '']\n",
    "        )\n",
    "        chunks = text_splitter.split_text(parent.page_content)\n",
    "\n",
    "    children = []\n",
    "    parent_id = parent.metadata.get(Metadata.ID)\n",
    "    for i, chunk in enumerate(chunks, start=1):\n",
    "        if isinstance(chunk, Document):\n",
    "            metadata = {**parent.metadata, **chunk.metadata}\n",
    "            employee_name = metadata.get('employee_name')\n",
    "            # Only augment chunks that sit under a top-level header; avoids '[Employee: None]'.\n",
    "            prefix = f'[Employee: {employee_name}] ' if employee_name else ''\n",
    "            content = prefix + chunk.page_content\n",
    "        else:\n",
    "            # The character splitter returns plain strings rather than Documents.\n",
    "            metadata = parent.metadata.copy()\n",
    "            content = chunk\n",
    "\n",
    "        metadata.update({\n",
    "            Metadata.ID: f'{parent_id}-{i}',\n",
    "            Metadata.PARENT_ID: parent_id,\n",
    "        })\n",
    "        children.append(Document(page_content=content, metadata=metadata))\n",
    "\n",
    "    logger.debug(f'Number chunks: {len(children)}, Parent ID: {parent_id}')\n",
    "\n",
    "    return children"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a90db6ee",
   "metadata": {},
   "source": [
    "### LLM Interaction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2e15e99",
   "metadata": {},
   "outputs": [],
   "source": [
    "def retrieve_context(query, k: int = 4) -> str:\n",
    "    '''Return the full parent documents of the chunks most similar to query.\n",
    "\n",
    "    Args:\n",
    "        query: User question used for the similarity search.\n",
    "        k: Number of child chunks to retrieve (4 is LangChain's default for similarity_search).\n",
    "\n",
    "    Returns:\n",
    "        Parent texts joined by blank lines. Parents are de-duplicated and kept in\n",
    "        the order their chunks were returned. Parents missing from docstore are\n",
    "        skipped. Returns an empty string when nothing matches.\n",
    "    '''\n",
    "    results = vectorstore.similarity_search(query, k=k)\n",
    "    logger.info(f'Matching records: {len(results)}')\n",
    "\n",
    "    # dict.fromkeys de-duplicates parent ids while keeping result order; one batched lookup.\n",
    "    parent_ids = list(dict.fromkeys(r.metadata.get(Metadata.PARENT_ID) for r in results))\n",
    "    parents = docstore.mget(parent_ids) if parent_ids else []\n",
    "\n",
    "    logger.info(f'Selected documents for query: {query} ids:{parent_ids}')\n",
    "    return '\\n\\n'.join(doc.page_content for doc in parents if doc is not None)\n",
    "\n",
    "        \n",
    "def ask(message, history):\n",
    "    '''Stream an answer to message, grounded on context retrieved from the indexed documents.\n",
    "\n",
    "    Args:\n",
    "        message: The user's question as plain text.\n",
    "        history: Earlier turns in Gradio 'messages' format (dicts with 'role' and\n",
    "            'content'). Only non-empty text turns are forwarded to the model.\n",
    "\n",
    "    Yields:\n",
    "        The accumulated response text after each streamed chunk. Leading empty\n",
    "        chunks are skipped, so the UI can render the answer progressively.\n",
    "    '''\n",
    "    context = retrieve_context(message)\n",
    "    # Built from separate lines so the prompt carries no stray source-code indentation.\n",
    "    prompt = (\n",
    "        'You are a helpful assistant that answers a question based on the provided context.\\n'\n",
    "        'If the context is not helpful to you in answering the question, say so.\\n'\n",
    "        'Be concise with your responses.\\n'\n",
    "        '\\n'\n",
    "        'Context:\\n'\n",
    "        f'{context}\\n'\n",
    "    )\n",
    "\n",
    "    messages = [('system', prompt)]\n",
    "    # Forward earlier text turns so follow-up questions keep their conversational context.\n",
    "    for turn in history or []:\n",
    "        if not isinstance(turn, dict):\n",
    "            continue\n",
    "        role, content = turn.get('role'), turn.get('content')\n",
    "        if role in ('user', 'assistant') and isinstance(content, str) and content:\n",
    "            messages.append((role, content))\n",
    "    messages.append(('user', message))\n",
    "\n",
    "    response_text = ''\n",
    "    for chunk in model.stream(messages):\n",
    "        response_text += chunk.content or ''\n",
    "        if response_text:\n",
    "            yield response_text"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c3e632dc-9e87-4510-9fcd-aa699c27e82b",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## Gradio UI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3d68a74",
   "metadata": {},
   "outputs": [],
   "source": [
    "def chat(message, history):\n",
    "    '''Gradio handler: index any uploaded files, then stream an answer to the typed text.\n",
    "\n",
    "    Args:\n",
    "        message: MultimodalTextbox value, a dict with 'text' and 'files' keys.\n",
    "        history: Earlier chat turns, passed through to ask.\n",
    "\n",
    "    Yields:\n",
    "        A status message when no question was typed, otherwise the streamed answer.\n",
    "    '''\n",
    "    if not message:\n",
    "        # Nothing to do; a bare return ends the generator without output.\n",
    "        return\n",
    "\n",
    "    text_input = (message.get('text') or '').strip()\n",
    "    files_uploaded = message.get('files') or []\n",
    "\n",
    "    indexed_count = 0\n",
    "    for uploaded in files_uploaded:\n",
    "        # Expected to be a file path string; accept {'path': ...} dicts defensively.\n",
    "        file_path = uploaded.get('path') if isinstance(uploaded, dict) else uploaded\n",
    "        if file_path:\n",
    "            index_document(file_path)\n",
    "            indexed_count += 1\n",
    "\n",
    "    if not text_input:\n",
    "        # Only report success when a file was actually submitted.\n",
    "        yield '✅ Indexed document' if indexed_count else 'Please type a question or upload a document.'\n",
    "        return\n",
    "\n",
    "    yield from ask(text_input, history)\n",
    "\n",
    "title = 'P-Chat 🔒💬'\n",
    "with gr.Blocks(title=title, fill_height=True) as ui:\n",
    "    gr.Markdown(f'# {title}')\n",
    "    gr.Markdown('## Privacy-focused bring-your-own-document (BYOD) solution 🤫.')\n",
    "\n",
    "    gr.ChatInterface(\n",
    "        fn=chat,\n",
    "        type='messages',\n",
    "        # The handler receives {'text': ..., 'files': [...]} from the multimodal textbox.\n",
    "        multimodal=True,\n",
    "        # Offer only the extensions load_documents can load (kept in sync with LOADER_MAPPING).\n",
    "        textbox=gr.MultimodalTextbox(file_types=list(LOADER_MAPPING), autofocus=True),\n",
    "    )\n",
    "\n",
    "ui.launch(debug=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
